{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Instructions\n",
    "\n",
    "**README** \n",
    "\n",
    "* Download the code from this [link](https://github.com/gmh14/data_efficient_grammar).\n",
    "* Download and unzip the log & checkpoint files from this [link](https://drive.google.com/file/d/12g28WNAgRGzaLtuG6ESg25W-uzlNrpLQ/view). \n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!cd data_efficient_grammar\n",
    "!conda create -n DEG_test python=3.6\n",
    "!conda activate DEG_test \n",
    "!conda install scipy pandas numpy scikit-learn\n",
    "!conda install pytorch torchvision torchaudio cpuonly -c pytorch\n",
    "!conda install -c rdkit rdkit\n",
    "\n",
    "!pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-1.10.0+cpu.html\n",
    "!pip install torch-geometric\n",
    "!pip install torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-1.10.0+cpu.html\n",
    "\n",
    "!pip install setproctitle\n",
    "!pip install graphviz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!conda install scipy pandas numpy scikit-learn\n",
    "!conda install pytorch torchvision torchaudio cpuonly -c pytorch\n",
    "!conda install -c rdkit rdkit\n",
    "!pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-1.10.0+cpu.html\n",
    "!pip install torch-geometric\n",
    "!pip install torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-1.10.0+cpu.html\n",
    "!pip install setproctitle\n",
    "!pip install graphviz\n",
    "!pip install pickle5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -e retro_star/packages/mlp_retrosyn\n",
    "!pip install -e retro_star/packages/rdchiral"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Play with the trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from private.hypergraph import Hypergraph, hg_to_mol\n",
    "from grammar_generation import random_produce\n",
    "\n",
    "\n",
    "from rdkit import Chem\n",
    "from rdkit.Chem import Draw\n",
    "import numpy as np\n",
    "from copy import deepcopy\n",
    "import pickle5 as pickle\n",
    "import torch\n",
    "from os import listdir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "expr_name_dict = dict()\n",
    "expr_name_dict['polymer_117motif'] = 'grammar-log/log_117motifs'\n",
    "expr_name_dict['iso'] = 'grammar-log/log_iso'\n",
    "expr_name_dict['acrylates'] = 'grammar-log/log_acy'\n",
    "expr_name_dict['chain_extender'] = 'grammar-log/log_ce'\n",
    "\n",
    "expr_names = list(expr_name_dict.keys())\n",
    "generated_mols = dict()\n",
    "for expr_name in expr_names:\n",
    "    print('dealing with {}'.format(expr_name))\n",
    "    ckpt_list = listdir(expr_name_dict[expr_name])\n",
    "    max_R = 0\n",
    "    max_R_ckpt = None\n",
    "    for ckpt in ckpt_list:\n",
    "        if 'grammar' in ckpt:\n",
    "            curr_R = float(ckpt.split('_')[4][:-4])\n",
    "            if curr_R > max_R:\n",
    "                max_R = curr_R\n",
    "                max_R_ckpt = ckpt\n",
    "    print('loading {}'.format(max_R_ckpt))\n",
    "    with open('{}/{}'.format(expr_name_dict[expr_name], max_R_ckpt), 'rb') as fr:\n",
    "        grammar = pickle.load(fr)\n",
    "    for i in range(8):\n",
    "        mol, _ = random_produce(grammar)\n",
    "        if expr_name not in generated_mols.keys():\n",
    "            generated_mols[expr_name] = [mol]\n",
    "        else:\n",
    "            generated_mols[expr_name].append(mol)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "exp = 'polymer_117motif' # 'iso', 'acrylates', 'chain_extender'\n",
    "Chem.Draw.MolsToGridImage(generated_mols[exp], molsPerRow=4, subImgSize=(200,200))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Train your own model (w/o optimization)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!python main.py --training_data=./datasets/**dataset_path**"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Check your model in \"log-num_generated_samples100-_timestamp_\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.11.1 (main, Dec 23 2022, 09:40:27) [Clang 14.0.0 (clang-1400.0.29.202)]"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "1a1af0ee75eeea9e2e1ee996c87e7a2b11a0bebd85af04bb136d915cefc0abce"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
